In [ ]:
# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Deep Learning Design Patterns - Code Labs

Lab Exercise #7 - Get Familiar with Alternative Connectivity Patterns

Prerequisites:

1. Familiarity with Python
2. Completed Chapter 3: Alternative Connectivity Patterns

Objectives:

1. Code a Squeeze-Excitation (SE) Link.
2. Add the SE-Link to an Inception V1 block.
3. Construct a mini SE-Inception V1 model.
4. Add the SE-Link to a DenseNet block.
5. Construct a mini SE-DenseNet model.

Let's now code a squeeze-and-excitation (SE) link. Remember, it uses dense layers with non-linear activation functions.

You will need to:

1. Determine the number of input channels to the SE link.
2. Set the number of filters for the squeeze.
3. Multiply the output of the excitation with the input to the link.
4. Set the ratio for the squeeze.

In [ ]:
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Conv2D, ReLU, BatchNormalization, Concatenate, Dense
# MaxPooling2D is used by the SE-Inception block later in this lab
from tensorflow.keras.layers import GlobalAveragePooling2D, Reshape, Multiply, MaxPooling2D

def se_link(inputs, **params): 
    ratio = params['ratio']
    
    # Get the number of input channels (feature maps)
    # HINT: It is always the last dimension
    n_filters = inputs.shape[??]
    
    # Global Average Pooling to 1x1xC
    outputs = GlobalAveragePooling2D()(inputs)
    outputs = Reshape((1, 1, n_filters))(outputs)
    
    # Dense w/ReLU to 1x1xC/ratio
    # HINT: divide the input number of filters by the ratio, and use the same activation as you do with Conv2D.
    outputs = Dense(n_filters // ??, activation='??')(outputs)
    
    # Dense w/Sigmoid to 1x1xC
    outputs = Dense(n_filters, activation='sigmoid')(outputs)
    
    # Multiply the SE output with the inputs
    # HINT: The output from the dense layer with the sigmoid activation
    outputs = Multiply()([inputs, ??])
    
    return outputs

inputs = Input((32, 32, 3))
# Set the squeeze ratio to 4
# HINT: It's the number 4.
outputs = se_link(inputs, **{'ratio': ??})
model = Model(inputs, outputs)
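
If you want to check your work, here is a minimal, self-contained reference sketch of the same SE link (assuming channels-last tensors and a default ratio of 4; the name se_link_reference is ours, not part of the exercise):

In [ ]:
from tensorflow.keras import Input, Model
from tensorflow.keras.layers import Dense, GlobalAveragePooling2D, Multiply, Reshape

def se_link_reference(inputs, ratio=4):
    n_filters = inputs.shape[-1]                         # channels are the last dimension
    x = GlobalAveragePooling2D()(inputs)                 # (H, W, C) -> (C,): one statistic per channel
    x = Reshape((1, 1, n_filters))(x)                    # back to a broadcastable (1, 1, C)
    x = Dense(n_filters // ratio, activation='relu')(x)  # squeeze: C // ratio units
    x = Dense(n_filters, activation='sigmoid')(x)        # excitation: per-channel gates in [0, 1]
    return Multiply()([inputs, x])                       # rescale the input feature maps

# Usage: same wiring as above, with 64 channels so the squeeze is non-degenerate
ref_inputs = Input((32, 32, 64))
ref_model = Model(ref_inputs, se_link_reference(ref_inputs))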

Verify the model using the summary() method.

It should look like the output below. Note that with only 3 input channels and a ratio of 4, the squeeze layer ends up with 3 // 4 = 0 units; with realistic channel counts, as in the blocks later in this lab, the squeeze layer is non-degenerate.

Model: "model_6"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_19 (InputLayer)           [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
global_average_pooling2d_16 (Gl (None, 3)            0           input_19[0][0]                   
__________________________________________________________________________________________________
reshape_2 (Reshape)             (None, 1, 1, 3)      0           global_average_pooling2d_16[0][0]
__________________________________________________________________________________________________
dense_22 (Dense)                (None, 1, 1, 0)      0           reshape_2[0][0]                  
__________________________________________________________________________________________________
dense_23 (Dense)                (None, 1, 1, 3)      3           dense_22[0][0]                   
__________________________________________________________________________________________________
multiply_4 (Multiply)           (None, 32, 32, 3)    0           input_19[0][0]                   
                                                                 dense_23[0][0]                   
==================================================================================================
Total params: 3
Trainable params: 3
Non-trainable params: 0

In [ ]:
model.summary()

Let's construct a mini-Inception V1 model with an SE-Link. We will call it SE-Inception V1. Our model will consist of:

1. A single SE-Inception block
2. A classifier

You will need to:

1. Complete the parameter arguments to the se_link() call.
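
The se_inception_block() below forwards its keyword arguments down to the SE link. As a refresher, here is a minimal standalone illustration of the **kwargs pass-through pattern it relies on (the function names inner and outer are made up for this example):

In [ ]:
def inner(x, **params):
    # consumes one keyword argument by name
    return x * params['ratio']

def outer(x, **params):
    # forwards every keyword argument it received down to inner()
    return inner(x, **params)

print(outer(3, **{'ratio': 4}))   # prints 12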

In [ ]:
def se_inception_block(inputs, **params):
    # pooling branch
    x1 = MaxPooling2D((3, 3), strides=(1, 1), padding='same')(inputs)
    x1 = Conv2D(64, (1, 1), strides=(1, 1), padding='same')(x1)
    
    # 1x1 branch
    x2 = Conv2D(64, (1, 1), strides=(1, 1), padding='same', activation='relu')(inputs)
    
    # 3x3 branch
    x3 = Conv2D(64, (1, 1), strides=(1, 1), padding='same')(inputs)
    x3 = Conv2D(96, (3, 3), strides=(1, 1), padding='same', activation='relu')(x3)
    
    # 5x5 branch
    x4 = Conv2D(64, (1, 1), strides=(1, 1), padding='same')(inputs)
    x4 = Conv2D(48, (5, 5), strides=(1, 1), padding='same', activation='relu')(x4)
    
    outputs = Concatenate()([x1, x2, x3, x4])
    
    # Insert the SE-Link 
    # HINT: the input to the SE-Link is the output from the concatenation. Use **kwargs syntax to pass down the ratio
    outputs = se_link(??)
    return outputs

def classifier(inputs, n_classes):
    outputs = GlobalAveragePooling2D()(inputs)
    outputs = Dense(n_classes, activation='softmax')(outputs)
    return outputs
    
inputs = Input((32, 32, 3))
outputs = se_inception_block(inputs, **{'ratio': 4})
outputs = classifier(outputs, 10)
model = Model(inputs, outputs)

Verify the model using the summary() method.

It should look like the output below.

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_10 (InputLayer)           [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
max_pooling2d_3 (MaxPooling2D)  (None, 32, 32, 3)    0           input_10[0][0]                   
__________________________________________________________________________________________________
conv2d_20 (Conv2D)              (None, 32, 32, 64)   256         input_10[0][0]                   
__________________________________________________________________________________________________
conv2d_22 (Conv2D)              (None, 32, 32, 64)   256         input_10[0][0]                   
__________________________________________________________________________________________________
conv2d_18 (Conv2D)              (None, 32, 32, 64)   256         max_pooling2d_3[0][0]            
__________________________________________________________________________________________________
conv2d_19 (Conv2D)              (None, 32, 32, 64)   256         input_10[0][0]                   
__________________________________________________________________________________________________
conv2d_21 (Conv2D)              (None, 32, 32, 96)   55392       conv2d_20[0][0]                  
__________________________________________________________________________________________________
conv2d_23 (Conv2D)              (None, 32, 32, 48)   76848       conv2d_22[0][0]                  
__________________________________________________________________________________________________
concatenate_2 (Concatenate)     (None, 32, 32, 272)  0           conv2d_18[0][0]                  
                                                                 conv2d_19[0][0]                  
                                                                 conv2d_21[0][0]                  
                                                                 conv2d_23[0][0]                  
__________________________________________________________________________________________________
global_average_pooling2d_5 (Glo (None, 272)          0           concatenate_2[0][0]              
__________________________________________________________________________________________________
reshape_3 (Reshape)             (None, 1, 1, 272)    0           global_average_pooling2d_5[0][0] 
__________________________________________________________________________________________________
dense_6 (Dense)                 (None, 1, 1, 68)     18564       reshape_3[0][0]                  
__________________________________________________________________________________________________
dense_7 (Dense)                 (None, 1, 1, 272)    18768       dense_6[0][0]                    
__________________________________________________________________________________________________
multiply_5 (Multiply)           (None, 32, 32, 272)  0           concatenate_2[0][0]              
                                                                 dense_7[0][0]                    
__________________________________________________________________________________________________
global_average_pooling2d_6 (Glo (None, 272)          0           multiply_5[0][0]                 
__________________________________________________________________________________________________
dense_8 (Dense)                 (None, 10)           2730        global_average_pooling2d_6[0][0] 
==================================================================================================
Total params: 173,326
Trainable params: 173,326
Non-trainable params: 0

In [ ]:
model.summary()

Training

Let's do a bit of training with your mini-Inception V1 + SE-Link (SE-Inception V1).

Dataset

Let's get the CIFAR-10 dataset built into tf.keras. These are 32x32 color images (3 channels) in 10 classes (airplanes, cars, birds, cats, deer, dogs, frogs, horses, ships, and trucks). We will preprocess the image data by scaling the pixel values to [0, 1] (preprocessing is not covered yet).


In [ ]:
from tensorflow.keras.datasets import cifar10
import numpy as np

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
x_train = (x_train / 255.0).astype(np.float32)
x_test  = (x_test / 255.0).astype(np.float32)
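
A quick sanity check on the shapes; these are the standard CIFAR-10 splits, so the printed values should match the comments:

In [ ]:
print(x_train.shape, y_train.shape)   # (50000, 32, 32, 3) (50000, 1)
print(x_test.shape, y_test.shape)     # (10000, 32, 32, 3) (10000, 1)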

Results

Let's train the model for 3 epochs. We compile with sparse_categorical_crossentropy because the CIFAR-10 labels are integer class IDs rather than one-hot vectors.

Because it is just a few epochs, your test accuracy may vary from run to run. For me, it was 40.5%.


In [ ]:
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.fit(x_train, y_train, epochs=3, batch_size=32, validation_split=0.1, verbose=1)
model.evaluate(x_test, y_test)

Add another SE-Inception V1 block

Let's improve our model by adding a second SE-Inception V1 block, but we will change the compression ratio in its SE-Link; a smaller ratio squeezes less aggressively, as the arithmetic below shows.

You will need to:

1. Complete the parameter arguments for the second call to the se_inception_block() method.
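
To see what the ratio changes, here is the squeeze width for the 272-channel output of the first block under both ratios (272 is the channel count of the concatenation in this model):

In [ ]:
channels = 272                             # output channels of the first SE-Inception block
for ratio in (4, 2):
    print(ratio, '->', channels // ratio)  # 4 -> 68, 2 -> 136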

In [ ]:
inputs = Input((32, 32, 3))
outputs = se_inception_block(inputs, **{'ratio': 4})
# Add the second se-inception block with ratio = 2
# HINT: use key-value syntax
outputs = se_inception_block(outputs, **{??})
outputs = classifier(outputs, 10)
model = Model(inputs, outputs)

Results

Let's train the model for just 2 epochs.

By going deeper (adding a second Inception block with an SE-Link), we get higher accuracy in 2 epochs than the shallower version reached in 3. For me, it was 51.6%.


In [ ]:
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.fit(x_train, y_train, epochs=2, batch_size=32, validation_split=0.1, verbose=1)
model.evaluate(x_test, y_test)

Let's construct a mini-DenseNet model with an SE-Link. We will call it SE-DenseNet. Our model will consist of:

1. An SE-DenseNet group with two blocks
2. A classifier

You will need to:

1. Extract the block and ratio settings (the parameter structure is illustrated in the sketch below).
2. Set the linear projection convolution to be a 1x1 kernel.
3. Pass the block-level parameters to the se_dense_block() calls.
4. Pass the group-level parameter ratio to the group() call.
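
For orientation, this is the shape of the parameter dictionary the group receives and how its pieces iterate; a minimal standalone illustration whose keys match the exercise below:

In [ ]:
group_params = {'blocks': [{'n_filters': 32}, {'n_filters': 64}], 'ratio': 4}

# each entry in 'blocks' parameterizes one dense block; 'ratio' applies group-wide
for block_params in group_params['blocks']:
    print(block_params['n_filters'], group_params['ratio'])
# 32 4
# 64 4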

In [ ]:
def group(inputs, **params):
    # Get the blocks and ratio data
    # HINT: they are keys in params
    blocks = ??
    ratio  = ??
    
    outputs = inputs
    for block_params in blocks:
        # Construct the next SE-DenseNet block
        # HINT: pass the parameters for this block as kwargs
        outputs = se_dense_block(outputs, **??, ratio=ratio)
    return outputs
        
def se_dense_block(inputs, **params):
    n_filters = params['n_filters']
    
    # 1x1 linear projection (dimensionality expansion)
    outputs = BatchNormalization()(inputs)
    outputs = ReLU()(outputs)
    # HINT: it's a 1x1 kernel
    outputs = Conv2D(n_filters * 4, ??, strides=(1, 1), padding='same')(outputs)
    
    # Bottleneck (dimensionality restoration)
    outputs = BatchNormalization()(outputs)
    outputs = ReLU()(outputs)
    outputs = Conv2D(n_filters, (3, 3), strides=(1, 1), padding='same')(outputs)
    
    # Add the SE-Link before the concatenation of the inputs to the outputs
    # HINT: pass thru the kwargs that were passed into the block method
    outputs = se_link(outputs, **??)
    
    # feature reuse
    outputs = Concatenate()([inputs, outputs])
    return outputs

inputs = Input((32, 32, 3))
# Set the squeeze ratio to 4
# HINT: pass it as a key/value pair: key 'ratio', value 4
outputs = group(inputs, **{ 'blocks': [{'n_filters': 32}, {'n_filters': 64}], ?? })
outputs = classifier(outputs, 10)
model = Model(inputs, outputs)

Verify the model using the summary() method.

It should look like the output below.

Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_29 (InputLayer)           [(None, 32, 32, 3)]  0                                            
__________________________________________________________________________________________________
batch_normalization_16 (BatchNo (None, 32, 32, 3)    12          input_29[0][0]                   
__________________________________________________________________________________________________
re_lu_16 (ReLU)                 (None, 32, 32, 3)    0           batch_normalization_16[0][0]     
__________________________________________________________________________________________________
conv2d_88 (Conv2D)              (None, 32, 32, 128)  512         re_lu_16[0][0]                   
__________________________________________________________________________________________________
batch_normalization_17 (BatchNo (None, 32, 32, 128)  512         conv2d_88[0][0]                  
__________________________________________________________________________________________________
re_lu_17 (ReLU)                 (None, 32, 32, 128)  0           batch_normalization_17[0][0]     
__________________________________________________________________________________________________
conv2d_89 (Conv2D)              (None, 32, 32, 32)   36896       re_lu_17[0][0]                   
__________________________________________________________________________________________________
global_average_pooling2d_28 (Gl (None, 32)           0           conv2d_89[0][0]                  
__________________________________________________________________________________________________
reshape_8 (Reshape)             (None, 1, 1, 32)     0           global_average_pooling2d_28[0][0]
__________________________________________________________________________________________________
dense_39 (Dense)                (None, 1, 1, 8)      264         reshape_8[0][0]                  
__________________________________________________________________________________________________
dense_40 (Dense)                (None, 1, 1, 32)     288         dense_39[0][0]                   
__________________________________________________________________________________________________
multiply_10 (Multiply)          (None, 32, 32, 32)   0           conv2d_89[0][0]                  
                                                                 dense_40[0][0]                   
__________________________________________________________________________________________________
concatenate_17 (Concatenate)    (None, 32, 32, 35)   0           input_29[0][0]                   
                                                                 multiply_10[0][0]                
__________________________________________________________________________________________________
batch_normalization_18 (BatchNo (None, 32, 32, 35)   140         concatenate_17[0][0]             
__________________________________________________________________________________________________
re_lu_18 (ReLU)                 (None, 32, 32, 35)   0           batch_normalization_18[0][0]     
__________________________________________________________________________________________________
conv2d_90 (Conv2D)              (None, 32, 32, 256)  9216        re_lu_18[0][0]                   
__________________________________________________________________________________________________
batch_normalization_19 (BatchNo (None, 32, 32, 256)  1024        conv2d_90[0][0]                  
__________________________________________________________________________________________________
re_lu_19 (ReLU)                 (None, 32, 32, 256)  0           batch_normalization_19[0][0]     
__________________________________________________________________________________________________
conv2d_91 (Conv2D)              (None, 32, 32, 64)   147520      re_lu_19[0][0]                   
__________________________________________________________________________________________________
global_average_pooling2d_29 (Gl (None, 64)           0           conv2d_91[0][0]                  
__________________________________________________________________________________________________
reshape_9 (Reshape)             (None, 1, 1, 64)     0           global_average_pooling2d_29[0][0]
__________________________________________________________________________________________________
dense_41 (Dense)                (None, 1, 1, 16)     1040        reshape_9[0][0]                  
__________________________________________________________________________________________________
dense_42 (Dense)                (None, 1, 1, 64)     1088        dense_41[0][0]                   
__________________________________________________________________________________________________
multiply_11 (Multiply)          (None, 32, 32, 64)   0           conv2d_91[0][0]                  
                                                                 dense_42[0][0]                   
__________________________________________________________________________________________________
concatenate_18 (Concatenate)    (None, 32, 32, 99)   0           concatenate_17[0][0]             
                                                                 multiply_11[0][0]                
__________________________________________________________________________________________________
global_average_pooling2d_30 (Gl (None, 99)           0           concatenate_18[0][0]             
__________________________________________________________________________________________________
dense_43 (Dense)                (None, 10)           1000        global_average_pooling2d_30[0][0]
==================================================================================================
Total params: 199,512
Trainable params: 198,668
Non-trainable params: 844

In [ ]:
model.summary()

Results

Let's train the model for just 2 epochs.

For me, the test accuracy was 42.2%.


In [ ]:
model.compile(loss='sparse_categorical_crossentropy', optimizer='adam', metrics=['acc'])
model.fit(x_train, y_train, epochs=2, batch_size=32, validation_split=0.1, verbose=1)
model.evaluate(x_test, y_test)

End of Lab Exercise